# This script makes a county-level
# data set with crop-specific decadal
# productivity shocks ("LocalEE"), 
# and the leave-one-out version
# of the same ("CropMixEE")


# Jacob Moscona and Karthik Sastry
# This version: last edited on September 28, 2022

import pandas as pd
import numpy as np
import os


##
## Options
## 

## These are mostly parallel with the same options
## in "make_crop_shocks.py"

## Directory for both raw data, and data outputs.
## Defaults to the authors' path; set the CLIMATE_CROPS_DATA_DIR environment
## variable to point at another copy of the data without editing this file.
wd = os.environ.get(
    'CLIMATE_CROPS_DATA_DIR',
    '/home/karthik/Dropbox (MIT)/climate_crops/QJE_Submission/Replication/Data/',
)
# All later reads/writes use paths relative to this directory
os.chdir(wd)

## What areas to use
## (Default, if both are false, is to use 1959 Census of Ag)
use_2012_area = True # 2012 census of agriculture
use_1955_area = False # 1955 census of agriculture

## Calculate shocks in the future using climatic projections?
use_projection = True
## If true, the following options set...
proj_year = 2090 # what future year to consider
path = '85' # what emissions pathway

## Whether to also calculate a "leave own state out"
## version (takes extra time)
do_state = False


def _file_suffix(projection, year, pathway, area_2012, area_1955):
    """Return the filename suffix shared by the input and output files.

    The suffix is ``'_<pathway>_<year>'`` when ``projection`` is true,
    followed by ``'_2012a'`` or ``'_1955a'`` for the chosen census areas
    (2012 takes precedence when both flags are set), e.g.
    ``_file_suffix(True, 2090, '85', True, False) == '_85_2090_2012a'``.
    """
    suffix = ''
    if projection:
        suffix += '_' + pathway + '_' + str(year)
    if area_2012:
        suffix += '_2012a'
    elif area_1955:
        suffix += '_1955a'
    return suffix


## Apply the options above, and
## determine the name of the relevant
## crop-shocks and crop-by-county files
## and make a filename for the output.
## All three names share one suffix, so the output file is always tagged
## with the same configuration as the inputs it was built from.
_suffix = _file_suffix(use_projection, proj_year, path, use_2012_area, use_1955_area)
name_crops = 'crop_shocks' + _suffix
name_counties = 'crop_x_county_shocks' + _suffix
name_output = 'county_shocks' + _suffix

if use_projection:
    year_range = [2010, proj_year] # Calculate these decades
else:
    year_range = range(1950, 2020, 10) # 1950, 1960 ... 2010

# Prefix of the per-crop area columns in the county data
if use_2012_area:
    colname = 'area2012'
elif use_1955_area:
    colname = 'area_1955'
else:
    colname = 'area'





##
## Load Datasets 
##

# Crop-level dataset, indexed by crop id
crop_dataset = pd.read_csv(name_crops + '.csv')
crop_dataset = crop_dataset.set_index('crop_id')

# County-by-crop shock data (one column per crop/feature/year)
county_x_crop_dataset = pd.read_csv(name_counties + '.csv')

# County areas, joined on their shared columns with the county-by-crop
# data, then indexed by (state, county) FIPS codes
county_dataset = (
    pd.read_csv('county_areas_raw.csv')
    .merge(county_x_crop_dataset)
    .set_index(['STATE_FIPS', 'CNTY_FIPS'])
)


##
## Prepare for leave-state-out calculation
##

states = pd.unique(county_dataset.index.get_level_values('STATE_FIPS'))
state_info = {}
croplist = crop_dataset.index.values
area_names = [colname + '_' + str(int(e)) for e in croplist] # Names of the area columns
ncrop = len(croplist) # Number of crops
# If a state's share of a crop's national area exceeds 1 - EPS (i.e. the
# area outside the state is below EPS), that crop is ignored in the
# leave-state-out calculation
EPS = 1e-2
lta = crop_dataset['log_total_area'].values
ta = np.exp(lta) # Total area per crop


def _area_weighted_mean(values, areas):
    """Return the area-weighted mean of each column of ``values`` over its rows.

    ``values`` and ``areas`` are same-shaped (counties x crops) arrays.
    Counties whose value is NaN are dropped from both the numerator and the
    denominator, so missing weather data is not counted as zero exposure.
    Columns with no usable area come out NaN.
    """
    usable = ~np.isnan(values)
    weighted_sum = np.nansum(np.where(usable, values * areas, 0.0), 0)
    usable_area = np.nansum(np.where(usable, areas, 0.0), 0)
    with np.errstate(divide='ignore', invalid='ignore'):
        return weighted_sum / usable_area


## Make the state-specific variables, that we can then subtract
## from the overall "other" variable to get a leave-state-out variable.
## state_info[state] = (area totals per crop, state share of national area
##                      per crop, {year: gdd average}, {year: days average})

if do_state:
    for state in states:

        # Area data for the counties of this state (counties x crops)
        state_area_data = county_dataset.loc[pd.IndexSlice[state, :], area_names]
        state_area_values = state_area_data.to_numpy(dtype=float)

        # Sum of area per crop, and the state's share of national area
        state_area_totals = state_area_data.sum(0)
        state_area_shares = state_area_totals.values / ta
        state_area_shares[state_area_shares > 1 - EPS] = 1 ## close enough to 1 to round

        state_gdd_avg = {}
        state_days_avg = {}

        for year in year_range:

            gdd_columns = [str(cx) + '_' + 'gddHot_' + str(year) for cx in croplist]
            days_columns = [str(cx) + '_' + 'daysHot_' + str(year) for cx in croplist]

            # Area-weighted GDD exposure over the whole state
            state_gdd_data = county_dataset.loc[pd.IndexSlice[state, :], gdd_columns]
            state_gdd_avg[year] = _area_weighted_mean(
                state_gdd_data.to_numpy(dtype=float), state_area_values)

            # Area-weighted days exposure over the whole state
            state_days_data = county_dataset.loc[pd.IndexSlice[state, :], days_columns]
            state_days_avg[year] = _area_weighted_mean(
                state_days_data.to_numpy(dtype=float), state_area_values)

        state_info[state] = (state_area_totals, state_area_shares, state_gdd_avg, state_days_avg)

##
## Main calculation
##

# Function that calculates LocalEE, CropMixEE per county, identified by its fips index
def create_row(index):
    """Compute the exposure measures for one county.

    Parameters
    ----------
    index : tuple
        ``(STATE_FIPS, CNTY_FIPS)`` key of a row of the module-level
        ``county_dataset``.

    Returns
    -------
    pandas.Series
        Float series named ``index``. For every ``year`` in ``year_range`` and
        ``feature`` in ``('gdd', 'days')`` it holds:

        * ``<feature>_own_<year>``: area-weighted local exposure (LocalEE);
        * ``<feature>_loo_<year>``: national crop-level exposure with this
          county's own contribution removed (CropMixEE);
        * when ``do_state`` is true, ``<feature>_lso_omit_<year>`` (number of
          crops dropped because the state holds more than ``1 - EPS`` of
          their national area) and ``<feature>_lso_<year>`` (leave-state-out
          exposure).

        All entries are NaN when no crop has both positive area and
        non-missing weather data. ``_lso_`` is NaN when every such crop was
        dropped.

    Reads the module globals ``county_dataset``, ``crop_dataset``,
    ``area_names``, ``croplist``, ``year_range``, ``ta``, ``do_state``,
    ``state_info`` and ``EPS``.
    """
    county_area = county_dataset.loc[index, area_names]

    if np.sum(county_area > 0):
        county_area_weights = county_area / np.sum(county_area)  # adds up to 1
    else:
        county_area_weights = 0.0 * county_area  # If all crops are 0

    # This county's share of total US area, for each crop
    crop_share = county_area.values / ta

    if do_state:  # For leave state out calculation
        # state_info[state] = (area totals, area shares, {year: gdd avg}, {year: days avg})
        this_state_info = state_info[index[0]]  # index[0] is the state fips
        state_share = np.asarray(this_state_info[1], dtype=float)
        # Crops for which this state holds (nearly) all national area
        too_big = state_share > 1 - EPS
        # Each feature is paired with its own state averages
        state_avgs = {'gdd': this_state_info[2], 'days': this_state_info[3]}

    values = {}

    for year in year_range:  # Loop over decades
        for feature in ['gdd', 'days']:  # Repeat for the GDD and days above calculation
            suffix = '_' + str(year)
            own_key = feature + '_own' + suffix
            loo_key = feature + '_loo' + suffix
            omit_key = feature + '_lso_omit' + suffix
            lso_key = feature + '_lso' + suffix

            f_columns = [str(cx) + '_' + feature + 'Hot_' + str(year) for cx in croplist]

            # Ecocrop exposure of each crop here; an explicit copy so that
            # filling NaNs can never write into county_dataset
            hot = county_dataset.loc[index, f_columns].to_numpy(dtype=float, copy=True)
            nahot = np.isnan(hot)
            hot[nahot] = 0
            wt_feature = county_area_weights.to_numpy(dtype=float, copy=True)
            wt_feature[nahot] = 0  # Ignore crops with NA temperature data
            total_wt = np.sum(wt_feature)

            if not total_wt > 0:
                # No crops with both area and weather data: code as missing
                values[own_key] = np.nan
                values[loo_key] = np.nan
                if do_state:
                    values[omit_key] = np.nan
                    values[lso_key] = np.nan
                continue

            wt_feature = wt_feature / total_wt  # Normalize weights to one
            max_opt = crop_dataset[feature + '_maxOpt_' + str(year)].values

            # Local extreme exposure, using the same renormalized weights as
            # the other measures so NA crops are excluded, not counted as 0
            values[own_key] = np.dot(wt_feature, hot)

            # Leave one out exposure: crop-level shocks with this county's
            # own contribution removed
            ec_reweight = max_opt / (1 - crop_share) - (crop_share / (1 - crop_share)) * hot
            values[loo_key] = np.dot(wt_feature, ec_reweight)

            if do_state:
                # Leave state out exposure. Crops the state (nearly) monopolizes
                # get zero weight and are counted in the _lso_omit_ column.
                lso_weights = wt_feature.copy()
                lso_weights[too_big] = 0
                values[omit_key] = np.sum(too_big)
                lso_total = np.nansum(lso_weights)
                if lso_total > 0:
                    state_avg = np.asarray(state_avgs[feature][year], dtype=float)
                    # Omitted crops have share 1 and produce inf here; they carry
                    # zero weight, and the resulting NaNs are dropped by nansum
                    with np.errstate(divide='ignore', invalid='ignore'):
                        ec_state = (max_opt / (1 - state_share)
                                    - (state_share / (1 - state_share)) * state_avg)
                        values[lso_key] = np.nansum(lso_weights / lso_total * ec_state)
                else:
                    # Every weighted crop was omitted: no leave-state-out value
                    values[lso_key] = np.nan

    return pd.Series(values, name=index, dtype='float')
                
# Now loop over counties to perform leave one out calculations, etc
county_index = county_dataset.index 
county_list = list(county_index)
n_counties = len(county_list)
rows = []
for i, county in enumerate(county_list, start=1):
    rows.append(create_row(county))
    # Format the percentage directly; round(x, 2) * 100 printed float
    # artifacts such as '56.99999999999999pct'
    print('about {:.0f}pct done'.format(100 * i / n_counties))

# One row per county, indexed by (state, county) FIPS codes
county_out = pd.concat(rows, axis=1).transpose()
county_out.index = county_out.index.set_names(('STATE_FIPS', 'CNTY_FIPS'))


##
## Save output
##

county_out.to_csv(name_output + '.csv')

